{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Getting MNIST Dataset...\n",
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "Data Extracted.\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "print('Getting MNIST Dataset...')\n",
    "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)\n",
    "print('Data Extracted.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def conv_layer(input, size_in, size_out, name=\"conv\"):\n",
    "    # 5x5 kernel size\n",
    "    # size_in input depth\n",
    "    # size_out map feature count\n",
    "    w = tf.Variable(tf.truncated_normal([5,5,size_in,size_out], stddev=0.1), name='w')\n",
    "    b = tf.Variable(tf.constant(0.1, shape=[size_in, size_out]), name='b')\n",
    "    conv = tf.nn.conv2d(input, w, strides=[1,1,1,1], padding='SAME')\n",
    "    activation = tf.nn.relu(conv + b)\n",
    "    return tf.nn.max_pool(activation, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def fc_layer(input, size_in, size_out, name='fc'):\n",
    "    w = tf.Variable(tf.truncated_normal([size_in, size_out], stddev=0.1), name='w')\n",
    "    b = tf.Variable(tf.constant(0.1, shape=[size_out], name='b'))\n",
    "    activation = tf.nn.relu(tf.matmul(input, w) + b)\n",
    "    return activation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def mnist_model(learning_rate, use_two_conv, use_two_fc, hparam):\n",
    "    tf.reset_default_graph()\n",
    "    sess = tf.Session()\n",
    "    x =  tf.placeholder(tf.float32, shape=[None, 784], name='x')\n",
    "    # 把图片reshape成conv需要的尺寸，正方形\n",
    "    x_image = tf.reshape(x, [-1, 28, 28, 1])\n",
    "    \n",
    "    y = tf.placeholder(tf.float32, shape=[None, 10], name='labels')\n",
    "    \n",
    "    if use_two_conv:\n",
    "        conv1 = conv_layer(x_image, 1, 32, \"conv1\")\n",
    "        conv_out = conv_layer(conv1, 32, 64, 'conv2')\n",
    "    else:\n",
    "        conv1 = conv_layer(x_image, 1, 64, 'conv')\n",
    "        conv_out = tf.nn.max_pool(conv1, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')\n",
    "    \n",
    "    # 28x28 5x5 pool-> 14x14 pool -> 7x7 max pool 两次\n",
    "    flattened = tf.reshape(conv_out, shape=[-1, 7*7*64])\n",
    "    \n",
    "    if use_two_conv:\n",
    "        fc1 = fc_layer(flattened, 7 * 7 * 64, 1024, 'fc1')\n",
    "        logits = fc_layer(fc1, 1024, 10, 'fc2')\n",
    "    else:\n",
    "        logits = fc_layer(flattened, 7 * 7 * 64 , 10, 'fc')\n",
    "    \n",
    "    xent = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y), name='xent')\n",
    "    train_step = tf.train.AdamOptimizer(learning_rate).minimize(xent)\n",
    "    \n",
    "    correction_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))\n",
    "    accuracy = tf.reduce_mean(tf.cast(correction_prediction, tf.float32))\n",
    "    \n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    for i in range(2001):\n",
    "        batch = mnist.train.next_batch(100)\n",
    "        if i % 50 == 0:\n",
    "            acc = sess.run(accuracy, feed_dict={x:batch[0], y:batch[1]})\n",
    "            print(\"accuracy : \", acc)\n",
    "        sess.run(train_step, feed_dict={x:batch[0], y:batch[1]})        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def make_hparam_string(learning_rate, use_two_fc, use_two_conv):\n",
    "    conv_param = \"conv=2\" if use_two_conv else \"conv=1\"\n",
    "    fc_param= \"fc=2\" if use_two_fc else \"fc=1\"\n",
    "    return \"lr_%.0E,%s,%s\" % (learning_rate, conv_param, fc_param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy :  0.1\n",
      "accuracy :  0.06\n",
      "accuracy :  0.08\n",
      "accuracy :  0.06\n",
      "accuracy :  0.13\n",
      "accuracy :  0.11\n",
      "accuracy :  0.08\n",
      "accuracy :  0.13\n",
      "accuracy :  0.06\n",
      "accuracy :  0.02\n",
      "accuracy :  0.09\n",
      "accuracy :  0.14\n",
      "accuracy :  0.09\n",
      "accuracy :  0.08\n",
      "accuracy :  0.18\n",
      "accuracy :  0.12\n",
      "accuracy :  0.1\n",
      "accuracy :  0.06\n",
      "accuracy :  0.07\n",
      "accuracy :  0.06\n",
      "accuracy :  0.1\n",
      "accuracy :  0.09\n",
      "accuracy :  0.12\n",
      "accuracy :  0.08\n",
      "accuracy :  0.07\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-14-60d623b493a2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmnist_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.01\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-11-084a5826f96e>\u001b[0m in \u001b[0;36mmnist_model\u001b[0;34m(learning_rate, use_two_conv, use_two_fc, hparam)\u001b[0m\n\u001b[1;32m     36\u001b[0m             \u001b[0macc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maccuracy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     37\u001b[0m             \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"accuracy : \"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0macc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 38\u001b[0;31m         \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_step\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/juwenz/anaconda3/envs/dlnd/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m    765\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    766\u001b[0m       result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 767\u001b[0;31m                          run_metadata_ptr)\n\u001b[0m\u001b[1;32m    768\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    769\u001b[0m         \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/juwenz/anaconda3/envs/dlnd/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m    963\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    964\u001b[0m       results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m--> 965\u001b[0;31m                              feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m    966\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    967\u001b[0m       \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/juwenz/anaconda3/envs/dlnd/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m   1013\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1014\u001b[0m       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m-> 1015\u001b[0;31m                            target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m   1016\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1017\u001b[0m       return self._do_call(_prun_fn, self._session, handle, feed_dict,\n",
      "\u001b[0;32m/Users/juwenz/anaconda3/envs/dlnd/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m   1020\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1021\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1022\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1023\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1024\u001b[0m       \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/juwenz/anaconda3/envs/dlnd/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m   1002\u001b[0m         return tf_session.TF_Run(session, options,\n\u001b[1;32m   1003\u001b[0m                                  \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1004\u001b[0;31m                                  status, run_metadata)\n\u001b[0m\u001b[1;32m   1005\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1006\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "mnist_model(0.01, False, False, \"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TODO\n",
    "\n",
    "### 1 弄明白代码里面的参数意义\n",
    "### 2 引入tensorboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
